#!/usr/bin/env python3
# ==============================================================================
#
# Copyright (C) 2022 Sophgo Technologies Inc.  All rights reserved.
#
# TPU-MLIR is licensed under the 2-Clause BSD License except for the
# third-party components.
#
# ==============================================================================

import ast
import copy
import os
import sys

import astunparse


class RewriteName(ast.NodeTransformer):
    """Turn bare-name calls ``f(...)`` into ``ClassName.f(...)``.

    Used to retarget calls between module-level functions once those
    functions have been moved into a class as static methods.

    NOTE(review): every call whose callee is a plain name is rewritten,
    including builtins such as ``len(...)``; this is only safe when the
    visited code calls nothing but the relocated functions.
    """

    def __init__(self, class_name):
        """
        Args:
            class_name: name of the class the rewritten calls are routed through.
        """
        super().__init__()
        self.class_name = class_name

    def visit_Call(self, node):
        """Rewrite ``node`` and every call nested inside it.

        Children (callee expression, positional and keyword arguments) are
        visited first, so that inner calls in ``f(g(x))`` or ``obj.m(g(x))``
        are rewritten as well. Calls whose callee is not a bare name
        (e.g. ``obj.m(...)``) keep their callee unchanged.
        """
        self.generic_visit(node)
        if isinstance(node.func, ast.Name):
            owner = ast.copy_location(
                ast.Name(id=self.class_name, ctx=ast.Load()), node.func
            )
            node.func = ast.copy_location(
                ast.Attribute(value=owner, attr=node.func.id, ctx=ast.Load()),
                node.func,
            )
        return node


class RemoveImport(ast.NodeTransformer):
    """AST pass that deletes ``import`` and ``from ... import`` statements."""

    def visit_Import(self, node):
        """Drop the statement: returning None makes NodeTransformer remove it."""
        return None

    # ``from x import y`` statements are dropped exactly the same way.
    visit_ImportFrom = visit_Import


def find_classes_and_funcs(node):
    """Split the direct children of ``node`` into class and function definitions.

    Only immediate children are inspected; methods nested inside classes are
    not collected. Source order is preserved in both lists.

    Returns:
        A ``(class_defs, func_defs)`` tuple of ``ast.ClassDef`` and
        ``ast.FunctionDef`` lists.
    """
    class_defs, func_defs = [], []
    for child in ast.iter_child_nodes(node):
        if isinstance(child, ast.ClassDef):
            class_defs.append(child)
        elif isinstance(child, ast.FunctionDef):
            func_defs.append(child)
    return class_defs, func_defs


def _as_static_method(func_def, class_name):
    """Return a copy of module-level ``func_def`` ready to live inside ``class_name``.

    The copy gets ``@staticmethod`` as its outermost decorator (ahead of any
    existing ones). Bare-name calls in its body are rewritten by
    ``RewriteName`` to go through ``class_name``. ``func_def`` itself is
    left unmodified.
    """
    # Deep-copy so each class gets its own nodes. RewriteName mutates the tree
    # it visits, so a body shared between classes would leave every class
    # after the first calling through the first class's name. Copying also
    # keeps every FunctionDef field (type_comment, locations, ...).
    static_def = copy.deepcopy(func_def)
    transformer = RewriteName(class_name)
    static_def.body = [transformer.visit(stmt) for stmt in static_def.body]
    static_def.decorator_list = [
        ast.Name(id="staticmethod", ctx=ast.Load())
    ] + static_def.decorator_list
    return static_def


def merge_files(python_files):
    """Merge generated Python modules into a single module's source text.

    For every file, each top-level class is kept, and import statements
    anywhere inside it are removed. Every top-level function of that same
    file is then appended to the class as a static method. Top-level
    functions of a file that defines no class are dropped, as is every other
    top-level statement (imports, assignments, ...).

    Args:
        python_files: iterable of paths to Python source files; classes
            appear in the output in this order.

    Returns:
        The merged module as a source string produced by ``astunparse``.
    """
    classes = []
    remove_imports = RemoveImport()

    for path in python_files:
        with open(path, "r", encoding="utf-8") as f:
            source = f.read()
        tree = ast.parse(source, filename=path)
        class_defs, func_defs = find_classes_and_funcs(tree)

        for class_def in class_defs:
            # Visit the ClassDef itself rather than mapping over its body.
            # NodeTransformer then deletes removed statements from the list;
            # mapping would leave None entries that unparse cannot handle.
            class_def = remove_imports.visit(class_def)
            # Static methods are added after import removal, so their bodies
            # are kept as written (apart from the call rewriting).
            class_def.body.extend(
                _as_static_method(func_def, class_def.name) for func_def in func_defs
            )
            classes.append(class_def)

    module = ast.Module(body=classes, type_ignores=[])
    return astunparse.unparse(module)


# Header written verbatim at the top of the merged output file (see the
# __main__ block). It records how the file was produced and defines the
# `flatbuffers` and `np` names at module level. Import statements inside the
# merged classes are stripped by RemoveImport, so the merged code presumably
# relies on these module-level names instead; verify against flatc output.
# The __main__ block may append extra lines (sys.argv[3]) to this string.
top_import = """
# Automatically generated by the FlatBuffers compiler and merged by merge_pyfbs.py.
# Do not modify.

# $ flatc -o . --python xxx.fbs
# $ merge_pyfbs.py python/models out_merged.py

import flatbuffers
from flatbuffers.compat import import_numpy

np = import_numpy()
"""


if __name__ == "__main__":
    # Validate with an explicit check. `assert` is stripped under `python -O`,
    # and `assert cond and "msg"` never displayed the usage text anyway.
    if len(sys.argv) < 3:
        sys.exit(
            "usage: merge_pyfbs.py ./python_fbs_folder output_file [extra_imports]"
        )
    directory = sys.argv[1]
    out_file = sys.argv[2]
    if len(sys.argv) > 3:
        # Optional third argument: extra source (e.g. import lines) appended
        # verbatim to the generated header.
        top_import += "\n" + sys.argv[3] + "\n"

    # Sort the listing: os.listdir order is arbitrary, and a fixed order makes
    # the merged output reproducible across runs and machines.
    python_files = [
        os.path.join(directory, name)
        for name in sorted(os.listdir(directory))
        if name.endswith(".py") and os.path.isfile(os.path.join(directory, name))
    ]

    merged_content = merge_files(python_files)
    with open(out_file, "w", encoding="utf-8") as f:
        f.write(top_import)
        f.write(merged_content)
